5ba634
@@ -17,18 +17,10 @@
  */
 package org.apache.hadoop.hive.ql.optimizer.calcite.translator;
 
-import java.math.BigDecimal;
-import java.math.BigInteger;
-import java.sql.Timestamp;
-import java.time.Instant;
-import java.util.ArrayList;
-import java.util.Calendar;
-import java.util.Date;
-import java.util.LinkedHashMap;
-import java.util.List;
-import java.util.Locale;
-import java.util.Map;
-
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableList.Builder;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Iterables;
 import org.apache.calcite.avatica.util.TimeUnit;
 import org.apache.calcite.avatica.util.TimeUnitRange;
 import org.apache.calcite.plan.RelOptCluster;
@@ -49,6 +41,7 @@
 import org.apache.calcite.sql.fun.SqlStdOperatorTable;
 import org.apache.calcite.sql.parser.SqlParserPos;
 import org.apache.calcite.sql.type.SqlTypeName;
+import org.apache.calcite.sql.type.SqlTypeUtil;
 import org.apache.calcite.util.ConversionUtil;
 import org.apache.calcite.util.DateString;
 import org.apache.calcite.util.NlsString;
@@ -104,9 +97,17 @@
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;
 
-import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableList.Builder;
-import com.google.common.collect.ImmutableMap;
+import java.math.BigDecimal;
+import java.math.BigInteger;
+import java.sql.Timestamp;
+import java.time.Instant;
+import java.util.ArrayList;
+import java.util.Calendar;
+import java.util.Date;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
 
 public class RexNodeConverter {
 
@@ -455,26 +456,76 @@
private RexNode handleExplicitCast(ExprNodeGenericFuncDesc func, List<RexNode> c
 
+  /**
+   * Rewrites the children of a Hive date-extraction call as a {@link TimeUnitRange} flag followed
+   * by the call's single argument.
+   *
+   * <p>If the argument is neither a datetime nor an interval type, it is wrapped in a nullable
+   * cast: to TIMESTAMP for HOUR, MINUTE and SECOND, and to DATE for every other unit.
+   *
+   * @param op the Hive extract operator being converted
+   * @param childRexNodeLst the already-converted children of the Hive call; must hold exactly one
+   *     element
+   * @return a new list holding the time-unit flag (omitted when {@code op} is not one of the
+   *     recognized {@link HiveExtractDate} operators) followed by the possibly-cast argument
+   * @throws IllegalArgumentException if {@code childRexNodeLst} has more than one element
+   * @throws java.util.NoSuchElementException if {@code childRexNodeLst} is empty
+   */
   private List<RexNode> rewriteExtractDateChildren(SqlOperator op, List<RexNode> childRexNodeLst)
       throws SemanticException {
-    List<RexNode> newChildRexNodeLst = new ArrayList<RexNode>();
+    List<RexNode> newChildRexNodeLst = new ArrayList<>(2);
+    // Picks the cast target below: sub-day units (HOUR/MINUTE/SECOND) -> TIMESTAMP, others -> DATE.
+    final boolean isTimestampLevel;
     if (op == HiveExtractDate.YEAR) {
       newChildRexNodeLst.add(cluster.getRexBuilder().makeFlag(TimeUnitRange.YEAR));
+      isTimestampLevel = false;
     } else if (op == HiveExtractDate.QUARTER) {
       newChildRexNodeLst.add(cluster.getRexBuilder().makeFlag(TimeUnitRange.QUARTER));
+      isTimestampLevel = false;
     } else if (op == HiveExtractDate.MONTH) {
       newChildRexNodeLst.add(cluster.getRexBuilder().makeFlag(TimeUnitRange.MONTH));
+      isTimestampLevel = false;
     } else if (op == HiveExtractDate.WEEK) {
       newChildRexNodeLst.add(cluster.getRexBuilder().makeFlag(TimeUnitRange.WEEK));
+      isTimestampLevel = false;
     } else if (op == HiveExtractDate.DAY) {
       newChildRexNodeLst.add(cluster.getRexBuilder().makeFlag(TimeUnitRange.DAY));
+      isTimestampLevel = false;
     } else if (op == HiveExtractDate.HOUR) {
       newChildRexNodeLst.add(cluster.getRexBuilder().makeFlag(TimeUnitRange.HOUR));
+      isTimestampLevel = true;
     } else if (op == HiveExtractDate.MINUTE) {
       newChildRexNodeLst.add(cluster.getRexBuilder().makeFlag(TimeUnitRange.MINUTE));
+      isTimestampLevel = true;
     } else if (op == HiveExtractDate.SECOND) {
       newChildRexNodeLst.add(cluster.getRexBuilder().makeFlag(TimeUnitRange.SECOND));
+      isTimestampLevel = true;
+    } else {
+      // NOTE(review): an unrecognized op gets no unit flag, so only the argument is returned.
+      isTimestampLevel = false;
     }
-    assert childRexNodeLst.size() == 1;
-    newChildRexNodeLst.add(childRexNodeLst.get(0));
+
+    final RexNode child = Iterables.getOnlyElement(childRexNodeLst);
+    if (SqlTypeUtil.isDatetime(child.getType()) || SqlTypeUtil.isInterval(child.getType())) {
+      newChildRexNodeLst.add(child);
+    } else {
+      // We need to add a cast to DATETIME Family
+      newChildRexNodeLst.add(
+          makeNullableCast(isTimestampLevel ? SqlTypeName.TIMESTAMP : SqlTypeName.DATE, child));
+    }
+
     return newChildRexNodeLst;
   }
+
+  /**
+   * Casts {@code child} to the nullable variant of the SQL type named by {@code typeName}.
+   *
+   * <p>{@code createSqlType} returns a NOT NULL type; casting to it would tell the planner the
+   * result can never be NULL, which is wrong when {@code child} is nullable or when the conversion
+   * can fail (presumably an unparseable string converts to NULL).
+   */
+  private RexNode makeNullableCast(SqlTypeName typeName, RexNode child) {
+    return cluster.getRexBuilder().makeCast(
+        cluster.getTypeFactory().createTypeWithNullability(
+            cluster.getTypeFactory().createSqlType(typeName), true),
+        child);
+  }
 
